import gi
import os
import argparse
import numpy as np
import setproctitle
import cv2
import hailo
import datetime

gi.require_version('Gst', '1.0')
from gi.repository import Gst, GLib
from hailo_apps_infra.hailo_rpi_common import (
    get_default_parser,
    detect_hailo_arch,
)
from hailo_apps_infra.gstreamer_helper_pipelines import (
    QUEUE,
    SOURCE_PIPELINE,
    INFERENCE_PIPELINE,
    INFERENCE_PIPELINE_WRAPPER,
    USER_CALLBACK_PIPELINE,
    TRACKER_PIPELINE,
    DISPLAY_PIPELINE,
)
from hailo_apps_infra.gstreamer_app import (
    GStreamerApp,
    app_callback_class,
)

# Approximate real-world widths, in meters, keyed by detection label.
# estimate_distance() uses these as the known object width; any label that is
# not listed falls back to DEFAULT_OBJECT_WIDTH.
OBJECT_WIDTHS = {
    "person": 0.4,
    "bicycle": 0.5,
    "car": 1.8,
    "motorcycle": 0.8,
    "bus": 2.5,
    "truck": 2.5,
    "airplane": 36.0,
    "train": 3.2,
    "boat": 5.0,
    "traffic light": 0.6,
    "fire hydrant": 0.3,
    "stop sign": 0.75,
    "cat": 0.3,
    "dog": 0.6,
    "horse": 1.2,
    "cow": 1.5,
    "elephant": 3.2,
    "bear": 1.7,
    "zebra": 1.2,
    "giraffe": 2.0,
    "bench": 1.2,
    "chair": 0.6,
    "couch": 2.0,
    "dining table": 1.8,
    "laptop": 0.4,
    "tv": 1.2,
}

DEFAULT_OBJECT_WIDTH = 0.5  # fallback width (meters) for labels missing above

# Scale factor in: distance = FOCAL_LENGTH * real_width / perceived_width.
# NOTE(review): an earlier comment described this as "in meters as per Raspberry
# Pi Camera Module 3 Wide specifications", but that lens's focal length is on the
# order of millimetres. If bbox widths are normalized to the frame (0..1), this
# value acts as focal length expressed as a fraction of frame width and should be
# calibrated for the actual camera — confirm.
FOCAL_LENGTH = 0.5

# Tracked objects: integer ID -> (label, bbox). Filled in by assign_id_to_object(),
# which app_callback() calls with this dict.
object_tracker = {}
IOU_THRESHOLD = 0.5  # IoU must be strictly greater than this to reuse a tracked ID

class user_app_callback_class(app_callback_class):
    """Callback state shared with ``app_callback`` across frames.

    Extends the framework's ``app_callback_class``, whose ``increment()`` /
    ``get_count()`` frame counter ``app_callback`` relies on, with a
    frame-skip factor.
    """

    def __init__(self, frame_skip=2):
        """Create the callback state.

        Args:
            frame_skip: ``app_callback`` processes only frames whose count is a
                multiple of this value (1 = every frame). Defaults to 2, as before.

        Raises:
            ValueError: If ``frame_skip`` is not a positive integer.
        """
        super().__init__()
        # app_callback uses frame_skip as a modulo divisor; 0 would raise
        # ZeroDivisionError on the first processed buffer, so reject it here.
        if not isinstance(frame_skip, int) or frame_skip < 1:
            raise ValueError(f"frame_skip must be a positive integer, got {frame_skip!r}")
        self.frame_skip = frame_skip

def get_object_width(obj_name):
    """Return the assumed real-world width (meters) for a detection label.

    Labels absent from OBJECT_WIDTHS get DEFAULT_OBJECT_WIDTH.
    """
    try:
        return OBJECT_WIDTHS[obj_name]
    except KeyError:
        return DEFAULT_OBJECT_WIDTH

def estimate_distance(obj_name, perceived_width):
    """Estimate the distance to an object from its known and perceived widths.

    Uses the similar-triangles relation
    ``distance = FOCAL_LENGTH * real_width / perceived_width``, where
    ``real_width`` comes from get_object_width().

    Args:
        obj_name: Detection label used to look up the real width.
        perceived_width: Bounding-box width reported by the detector
            (presumably normalized to the frame width — confirm).

    Returns:
        The estimate rounded to 2 decimals, or None when ``perceived_width`` is
        not strictly positive.
    """
    real_width = get_object_width(obj_name)
    if not perceived_width > 0:
        # Zero, negative (or NaN) widths give no meaningful estimate.
        return None
    distance = FOCAL_LENGTH * real_width / perceived_width
    return round(distance, 2)

def iou(boxA, boxB):
    """Compute the Intersection over Union (IoU) of two axis-aligned boxes.

    Each box only needs ``xmin()``, ``ymin()``, ``xmax()`` and ``ymax()``
    accessor methods.

    Returns:
        The IoU as a float (normally in [0, 1]). Returns 0.0 when the union
        area is not positive (e.g. two zero-area boxes) instead of raising
        ZeroDivisionError.
    """
    ax0, ay0, ax1, ay1 = boxA.xmin(), boxA.ymin(), boxA.xmax(), boxA.ymax()
    bx0, by0, bx1, by1 = boxB.xmin(), boxB.ymin(), boxB.xmax(), boxB.ymax()

    # Overlap extents are clamped at 0 so disjoint boxes contribute no area.
    inter_w = max(0, min(ax1, bx1) - max(ax0, bx0))
    inter_h = max(0, min(ay1, by1) - max(ay0, by0))
    inter_area = inter_w * inter_h

    area_a = (ax1 - ax0) * (ay1 - ay0)
    area_b = (bx1 - bx0) * (by1 - by0)
    union = area_a + area_b - inter_area

    # Degenerate boxes (zero width or height) can produce an empty union;
    # treat that as "no overlap" rather than crashing the pipeline callback.
    if union <= 0:
        return 0.0
    return inter_area / float(union)

def assign_id_to_object(detection, object_tracker):
    """Return a persistent integer ID for ``detection``.

    The detection is matched against ``object_tracker`` entries with the same
    label. The candidate with the highest IoU strictly above IOU_THRESHOLD wins,
    and its stored bbox is refreshed to the current one. If nothing matches, a
    new ID is registered.

    Args:
        detection: Object exposing ``get_bbox()`` and ``get_label()``.
        object_tracker: Mutable mapping ``id -> (label, bbox)``; updated in place.

    Returns:
        The matched or newly assigned integer ID.

    NOTE(review): nothing here evicts stale entries, so the tracker (and the
    linear scan below) grows for the lifetime of the process.
    """
    bbox = detection.get_bbox()
    obj_name = detection.get_label()

    # Choose the best-overlapping candidate, not merely the first one above the
    # threshold, so the result does not depend on dict insertion order.
    best_id = None
    best_iou = IOU_THRESHOLD
    for obj_id, (prev_name, prev_bbox) in object_tracker.items():
        if prev_name != obj_name:
            continue
        overlap = iou(bbox, prev_bbox)
        if overlap > best_iou:
            best_id, best_iou = obj_id, overlap

    if best_id is None:
        # max(existing) + 1 stays unique even if entries are ever removed;
        # len() + 1 could then collide with a live ID.
        best_id = max(object_tracker, default=0) + 1

    object_tracker[best_id] = (obj_name, bbox)
    return best_id

class GStreamerInstanceSegmentationApp(GStreamerApp):
    """GStreamer app running the yolov5n_seg HEF on a Hailo device.

    Pipeline: source -> Hailo inference -> tracker -> user callback -> display.
    """

    def __init__(self, app_callback, user_data):
        """Parse command-line args, configure the model resources and build the pipeline.

        Args:
            app_callback: Callback stored on ``self.app_callback`` (presumably
                attached by GStreamerApp to the user-callback stage — confirm).
            user_data: State object passed through to GStreamerApp.
        """
        args = get_default_parser().parse_args()
        super().__init__(args, user_data)

        # Input frame geometry handed to SOURCE_PIPELINE.
        self.video_width, self.video_height = 640, 640
        # Frames grouped together per Hailo inference call.
        self.batch_size = 2

        # CLI values win; otherwise auto-detect the device / use the bundled HEF.
        self.arch = args.arch or detect_hailo_arch()
        self.hef_path = args.hef_path or os.path.join(self.current_path, '../resources/yolov5n_seg_h8l_mz.hef')

        # Post-processing library, entry point and its JSON config.
        self.config_file = os.path.join(self.current_path, '../resources/yolov5n_seg.json')
        self.default_post_process_so = os.path.join(self.current_path, '../resources/libyolov5seg_postprocess.so')
        self.post_function_name = "filter_letterbox"

        self.app_callback = app_callback

        setproctitle.setproctitle("Hailo Instance Segmentation App")
        self.create_pipeline()

    def get_pipeline_string(self):
        """Return the pipeline description, with stages linked by ' ! '."""
        stages = [
            SOURCE_PIPELINE(
                video_source=self.video_source,
                video_width=self.video_width,
                video_height=self.video_height,
            ),
            INFERENCE_PIPELINE_WRAPPER(
                INFERENCE_PIPELINE(
                    hef_path=self.hef_path,
                    post_process_so=self.default_post_process_so,
                    post_function_name=self.post_function_name,
                    batch_size=self.batch_size,
                    config_json=self.config_file,
                )
            ),
            # NOTE(review): class_id=1 presumably limits tracking to one class — confirm.
            TRACKER_PIPELINE(class_id=1),
            USER_CALLBACK_PIPELINE(),
            DISPLAY_PIPELINE(
                video_sink=self.video_sink,
                sync=self.sync,
                show_fps=self.show_fps,
            ),
        ]
        return " ! ".join(f"{stage}" for stage in stages)

def app_callback(pad, info, user_data):
    """Pad-probe callback: print an ID and estimated distance for each detection.

    The callback counts every buffer via ``user_data.increment()`` and processes
    only those whose count is a multiple of ``user_data.frame_skip``. For each
    HAILO_DETECTION in the buffer's ROI it prints either
    ``<timestamp>:<id>:<Label>:<distance>M`` or a "distance unknown" line.

    Args:
        pad: Pad the probe is attached to (unused).
        info: Probe info whose buffer carries the Hailo ROI metadata.
        user_data: Shared callback state (frame counter and ``frame_skip``).

    Returns:
        Gst.PadProbeReturn.OK in every case, so buffers always flow on.
    """
    buffer = info.get_buffer()
    if buffer is None:
        return Gst.PadProbeReturn.OK

    user_data.increment()
    if user_data.get_count() % user_data.frame_skip != 0:
        return Gst.PadProbeReturn.OK

    detections = hailo.get_roi_from_buffer(buffer).get_objects_typed(hailo.HAILO_DETECTION)

    # One timestamp per processed frame: all detections in this buffer share
    # it, and it avoids a clock read + strftime per object.
    # NOTE(review): '=' between date and time looks unusual; kept as-is in case
    # downstream parsing depends on it.
    timestamp = datetime.datetime.now().strftime("%d/%m=%H:%M:%S")

    for detection in detections:
        obj_name = detection.get_label()
        bbox = detection.get_bbox()
        perceived_width = bbox.width() if bbox else 0
        distance = estimate_distance(obj_name, perceived_width)

        obj_id = assign_id_to_object(detection, object_tracker)

        # estimate_distance signals "unknown" with None; a rounded 0.0 is still
        # a real estimate and must not be reported as unknown.
        if distance is not None:
            print(f"{timestamp}:{obj_id}:{obj_name.capitalize()}:{distance}M")
        else:
            print(f"{timestamp} ID: {obj_id} {obj_name.capitalize()} detected, distance unknown.")

    return Gst.PadProbeReturn.OK

if __name__ == "__main__":
    # Shared state handed to every app_callback invocation.
    user_data = user_app_callback_class()
    # The constructor parses CLI args and builds the pipeline; run() starts it.
    app = GStreamerInstanceSegmentationApp(app_callback=app_callback, user_data=user_data)
    app.run()

